Change from thread idx based predicate to warp idx based#5821
Conversation
Review updated until commit 122ac43

Description

Relevant files

Enhancement
PR Reviewer Guide
Here are some key observations to aid the review process:
🧪 PR contains tests
⚡ Recommended focus areas for review
Error Handling
Test failures
- (Medium, 5) NVFuser internal assert: NumComputeWarps not constant in TmaPersistent tests

  | Test Name | GB200 | Source |
  |---|---|---|
  | TmaPersistentTestF.KernelReuse | ❌ | Link |
  | TmaPersistentTestP.TmaInnerPersistentRmsNorm/__bfloat_2048_1024 | ❌ | Link |
  | TmaPersistentTestP.TmaInnerPersistentRmsNorm/__bfloat_2048_2048 | ❌ | Link |
  | TmaPersistentTestP.TmaInnerPersistentRmsNorm/__bfloat_2048_4096 | ❌ | Link |
  | TmaPersistentTestP.TmaInnerPersistentRmsNorm/__bfloat_2048_8192 | ❌ | Link |

- (Medium, 2) Thunder nvFuser nanoGPT autograd scalar mismatch in thunder.tests.test_networks

  | Test Name | GB200 | H100 | Source |
  |---|---|---|---|
  | thunder.tests.test_networks.test_nanogpt_complete_autograd_nvfuser_cuda_thunder.dtypes.float32 | ❌ | ❌ | |

- (Medium, 2) nvFuser warp_specialize codegen assertion failure in tests.python.direct.test_tutorial

  | Test Name | GB200 | H100 | Source |
  |---|---|---|---|
  | tests.python.direct.test_tutorial.test_warp_specialized_circular_buffering_pointwise | ❌ | ❌ | |
Greptile Summary

This PR refactors warp-specialized kernel predicate generation to use uniform warp IDs instead of thread indices, providing significant performance improvements through compiler optimization.

Key Changes:

Performance Impact:

Architecture:

Confidence Score: 4/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
participant Compiler as NVIDIA Fuser Compiler
participant Alloc as AllocationInserter
participant PDM as ParallelDimensionMap
participant PC as PredicateCompute
participant CB as CircularBuffer
Compiler->>Alloc: Start allocation pass
activate Alloc
Alloc->>PDM: Check canUseWarpIdBasedPredicate()
activate PDM
PDM-->>Alloc: Returns true if consecutive warp IDs possible
deactivate PDM
alt Warp ID based predicates available
Alloc->>Alloc: Compute uniform warp ID<br/>(tid = tidx + tidy*bdimx + tidz*bdimx*bdimy)<br/>(warp_id = tid / 32)
Alloc->>Compiler: Store uniform_warp_id in GpuLower
Compiler->>PC: Generate predicates
activate PC
PC->>PC: selectFirstWarpElectSyncPredicate()
Note over PC: Use warp_id == 0
PC->>PC: createElectSyncPredicateAsync()
Note over PC: Use warp_id == num_compute_warps
PC->>PC: createMultipleExpressionElectSync()
Note over PC: Route to async/compute warp predicates
deactivate PC
Compiler->>CB: getAsyncWarpPredicate()
activate CB
CB->>PDM: Get getNumComputeWarps()
Note over CB: Use warp_id >= num_compute_warps
CB-->>Compiler: Return warp ID based predicate
deactivate CB
else No consecutive warp IDs (fallback)
Note over Alloc: Skip warp ID computation
Alloc-->>Compiler: uniform_warp_id = nullptr
Compiler->>PC: Generate predicates (fallback)
activate PC
PC->>PC: Use thread index based predicates<br/>(threadIdx.x < 32 for first warp)<br/>(linear_index >= computed_threshold for async)
deactivate PC
Compiler->>CB: getAsyncWarpPredicate (fallback)
activate CB
CB->>PDM: Use getWarpSpecializationPaddedVal()
Note over CB: Use parallel_index >= (raw - padding)
CB-->>Compiler: Return thread index based predicate
deactivate CB
end
Compiler->>Compiler: Generate optimized PTX/SASS<br/>with reduced WARPSYNC/VOTEU.ALL instructions
deactivate Alloc
```
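To make the diagram concrete, below is a minimal CUDA sketch of what the two predicate paths could look like in generated device code. This is illustrative only, not the PR's actual codegen output; the kernel name and the `kNumComputeWarps` value are placeholders.

```cuda
// Illustrative sketch only; names and constants are placeholders, not the
// PR's generated code.
__global__ void warpSpecializedSketch() {
  constexpr unsigned kNumComputeWarps = 8;  // assumed example value

  // Warp-ID-based path: linearize the thread index and derive a warp ID that
  // is identical for every thread of a warp. (The review discussion below adds
  // a __shfl_sync broadcast so PTXAS can actually prove this uniformity.)
  const unsigned tid = threadIdx.x + threadIdx.y * blockDim.x +
      threadIdx.z * blockDim.x * blockDim.y;
  const unsigned warp_id = tid / 32;

  if (warp_id >= kNumComputeWarps) {
    // Async warps (e.g. TMA loads). The thread-index fallback would instead
    // compare a linear index against a computed threshold.
  } else if (warp_id == 0) {
    // First compute warp (elect-sync style selection). The thread-index
    // fallback would use threadIdx.x < 32 here.
  }
}
```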
!test
csrc/predicate_compute.cpp
Outdated
```cpp
// Val* zero = IrBuilder::create<Val>(0L, PrimDataType::UInt64);
// Val* warp_size = IrBuilder::create<Val>(32L, PrimDataType::UInt64);

// const ParallelDimensionMap& pdim_map =
//     GpuLower::current()->info().parallelDimensionMap();
// Val* async_warp_thread_index = pdim_map.getLinearThreadIndexAsync();
// Val* warp_id =
//     SimplifyingIrBuilder::divExpr(async_warp_thread_index, warp_size);
// // TODO Only select first warp now
// Val* select_warp = SimplifyingIrBuilder::eqExpr(warp_id, zero);

// // Use elect-sync if available
// if (pdim_map.canUseElectSyncInAsyncWarp()) {
//   return SimplifyingIrBuilder::logicalAndExpr(
//       select_warp, createElectSyncExpr());
// }

// // Warp Specialized ParallelType is ThreadIdx.x and it contains less than 32
// // threads, so manually select first thread in warp.
// Val* thread_id =
//     SimplifyingIrBuilder::modExpr(async_warp_thread_index, warp_size);
// Val* select_thread = SimplifyingIrBuilder::eqExpr(thread_id, zero);
// return SimplifyingIrBuilder::logicalAndExpr(select_warp, select_thread);
```
style: Remove commented-out legacy code
0acf1bd to ded2693
!test

!test

!test

!test
```cpp
// Compute warp_id = tid / 32
Val* warp_size = IrBuilder::create<Val>(32L, DataType::Index);
Val* warp_id = SimplifyingIrBuilder::divExpr(tid, warp_size);
```
It seems like there is nothing guaranteeing this will be warp-uniform, since the compiler cannot know the block size, so unless TIDz and TIDy are both >1 it won't know that tid is the linear thread ID. So do we need to do a warp broadcast? See #2323.
yes, we need something like:

```cuda
// __shfl_sync helps PTXAS prove that every thread in the warp has the same
// uniform warp id.
__device__ __forceinline__ uint32_t getUniformWarpId() {
  const unsigned int tid = threadIdx.x + threadIdx.y * blockDim.x +
      threadIdx.z * blockDim.x * blockDim.y;
  const unsigned int warp_id = tid / 32;
  return __shfl_sync(0xFFFFFFFF, warp_id, 0);
}
```
This PR is not ready yet.
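For illustration only (not code from the PR), a minimal usage sketch of the `getUniformWarpId()` helper above; the kernel name and `kNumComputeWarps` are placeholders.

```cuda
// Illustrative usage of the getUniformWarpId() helper defined above;
// kNumComputeWarps is a placeholder, not a value taken from the PR.
__global__ void warpSpecializeUsageSketch() {
  constexpr unsigned kNumComputeWarps = 8;
  const unsigned warp_id = getUniformWarpId();  // warp-uniform by construction

  if (warp_id >= kNumComputeWarps) {
    // async warps: e.g. issue TMA loads
  } else {
    // compute warps
  }
}
```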
Test:

```bash
NVFUSER_DUMP=ptx,sass,scheduler_params,launch_param,cuda_to_file ./test_nvfuser --gtest_filter=*ThunderRMSNormBwd*bfloat*16384* 2>&1 | tee new.log
```

SASS Code Comparison: old.log vs new3.log
Summary
Use `UniformWarpId()` instead of `threadIdx.x/y/z` for warp specialization predicates.

Key Differences
1. Warp ID Calculation
OLD (old.log):
- `threadIdx.x` directly (stored in R5)

NEW (new3.log):

- `R0 >> 5` (equivalent to `tid / 32`)
- `SHFL.IDX` instruction to ensure all threads in a warp have the same value

2. Predicate Comparison
OLD (old.log):
- `threadIdx.x >= 256` (0x100)

NEW (new3.log):

- `UniformWarpId >= 8` (0x8), i.e. the same 256-thread boundary expressed in warps (256 / 32 = 8)

3. Code Size
- `.L_x_42` vs `.L_x_44`

4. Register Usage
Both versions use similar register allocation patterns, but:
5. Thread Index Reads
OLD (old.log):
NEW (new3.log):
- `tid = threadIdx.z * blockDim.y * blockDim.x + threadIdx.y * blockDim.x + threadIdx.x`

Benefits of the New Approach
- The `SHFL.IDX` instruction explicitly proves to PTXAS that all threads in a warp have the same warp ID value

Performance Impact
🚀 Instruction Count Improvements (MAJOR WINS!)
⚡ WARPSYNC Instruction Reduction
Where WARPSYNC was eliminated:
OLD version had extra WARPSYNC at:
NEW version simplified to:
💡 Why This Happens
The `SHFL.IDX` instruction explicitly proves to PTXAS that all threads in a warp have the same warp ID. This allows the compiler to:

Eliminate VOTEU.ALL instructions (-8 instructions, -89%): When PTXAS knows the warp ID is uniform, it doesn't need expensive warp-level ballot operations to check uniformity
- Old pattern: `ISETP` → `VOTEU.ALL` → check result

Eliminate redundant WARPSYNC (-2 instructions, -40%): When PTXAS knows values are uniform, it doesn't need as many WARPSYNC barriers before ELECT operations
Simplify predicates (-5 ISETP instructions): Fewer comparison operations needed when uniformity is established early
Optimize control flow: Better branch prediction and divergence handling
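For intuition, here is a small self-contained sketch (an assumed example, not nvFuser's generated code) of the two source-level patterns this analysis contrasts; the 256-thread / 8-warp threshold comes from the comparison above, and the comments restate the claims made here about what PTXAS emits.

```cuda
// Assumed, self-contained example contrasting the two predicate styles
// discussed above; not taken from nvFuser's generated code.
__global__ void oldStylePredicate(float* out) {
  // Per the analysis, PTXAS cannot prove this condition is identical for all
  // threads of a warp (the block shape is unknown at compile time), so it
  // guards warp-collective code with VOTEU.ALL / extra WARPSYNC.
  if (threadIdx.x >= 256) {
    out[threadIdx.x] = 1.0f;  // stand-in for async-warp work
  }
}

__global__ void newStylePredicate(float* out) {
  const unsigned tid = threadIdx.x + threadIdx.y * blockDim.x +
      threadIdx.z * blockDim.x * blockDim.y;
  // Broadcasting lane 0's warp ID makes the value provably identical across
  // the warp, which is what lets PTXAS drop the uniformity checks.
  const unsigned warp_id = __shfl_sync(0xFFFFFFFF, tid / 32, 0);
  if (warp_id >= 8) {
    out[tid] = 1.0f;  // stand-in for async-warp work
  }
}
```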
📊 Overall Performance Impact
Cost vs Benefit Analysis:
Added overhead:
- `SHF.R.U32.HI` + `SHFL.IDX`
Net Performance Gain:
🎯 Summary
This is a clear win! Trading 2-3 cycles for 100-200+ cycles saved, with the biggest gains from: